import os
import gzip
import shutil
import tempfile
from collections import Counter
import pybedtools
import urllib.parse
from numpy import *


# Genome assembly identifier (not referenced in the visible part of this script).
assembly = "hg38"

# Time points accepted in library file names; CAGE file names tag these
# values with an "hr" token (see select_libraries).
timepoints = (
    0,
    1,
    4,
    12,
    24,
    96,
)

def _strip_ctss_suffixes(filename):
    """Return `filename` minus its `.ctss.bed.gz` suffix, or None if it has a different suffix."""
    rootname = filename
    # Peel the extensions off from the outside in.
    for expected in (".gz", ".bed", ".ctss"):
        rootname, extension = os.path.splitext(rootname)
        if extension != expected:
            return None
    return rootname


def select_libraries(dataset, timepoint,
                     directory="/osc-fs_home/mdehoon/Data/CASPARs"):
    """Yield one BedTool per CTSS library of `dataset` sampled at `timepoint`.

    Files are read from `<directory>/<dataset>/CTSS` in sorted order; only
    names ending in `.ctss.bed.gz` are considered.  The part of the name
    before the first dot is parsed as:

    * CAGE:  `<timepoint>_hr_<replicate>` with a single-letter replicate A-H
    * HiSeq: `t<timepoint>_<replicate>` with replicate r1, r2 or r3
      (`t1_r3` is skipped: it is a negative control library)

    Arguments:
     - dataset:   'CAGE' or 'HiSeq'
     - timepoint: integer time point to select; must be one of `timepoints`
     - directory: root directory holding the per-dataset subdirectories

    Each yielded BedTool wraps an open gzip text handle.  The handle is
    closed as soon as the generator is resumed or closed, so a library must
    be fully consumed before the next one is requested.

    Raises Exception for an unknown dataset, and AssertionError (or
    ValueError from the parsing steps) for file names that do not follow
    the expected pattern.
    """
    subdirectory = os.path.join(directory, dataset, "CTSS")
    for filename in sorted(os.listdir(subdirectory)):
        rootname = _strip_ctss_suffixes(filename)
        if rootname is None:
            continue
        terms = rootname.split(".")
        sample_timepoint, replicate = terms[0].rsplit("_", 1)
        if dataset == 'CAGE':
            # Require exactly one letter; a plain substring test would also
            # accept "" or "AB".
            assert len(replicate) == 1 and replicate in "ABCDEFGH"
            sample_timepoint, hr = sample_timepoint.split("_")
            assert hr == "hr"
            sample_timepoint = int(sample_timepoint)
        elif dataset == 'HiSeq':
            assert replicate in ("r1", "r2", "r3")
            assert sample_timepoint.startswith("t")
            sample_timepoint = int(sample_timepoint[1:])
            if sample_timepoint == 1 and replicate == "r3":
                # negative control library using water instead of RNA
                continue
        else:
            raise Exception("Unknown dataset %s" % dataset)
        assert sample_timepoint in timepoints
        if sample_timepoint != timepoint:
            continue
        path = os.path.join(subdirectory, filename)
        print("Reading", path)
        # The context manager closes the handle even when the consumer stops
        # iterating early or an exception is thrown into the generator.
        with gzip.open(path, "rt") as handle:
            yield pybedtools.BedTool(handle)

def find_dominant_tss(peak, counts):
    """Return the position inside `peak` that carries the highest count.

    Arguments:
     - peak:   object with integer `start` and `end` attributes; the peak
               covers positions start, start+1, ..., end-1
     - counts: sequence of end - start counts, one per position

    Returns a Python int.  The position is chosen as follows:

     1. If there are no counts, or all counts are zero, the midpoint
        (start + end) // 2 of the peak is returned.
     2. Otherwise the position with the maximum count is returned.
     3. If several positions share the maximum, the one closest to the
        count-weighted mean position is returned.
     4. If that is still ambiguous, the one closest to the midpoint of the
        peak is returned (the left-most one on a final tie).
    """
    start = peak.start
    end = peak.end
    counts = asarray(counts)
    if counts.size == 0:
        return (start + end) // 2
    maxcount = counts.max()
    if maxcount == 0:
        return (start + end) // 2
    # Genomic positions of all entries that reach the maximum count.
    candidates = flatnonzero(counts == maxcount) + start
    if len(candidates) == 1:
        return int(candidates[0])
    # Count-weighted mean position over the whole peak.
    weighted_mean = dot(counts, arange(start, end)) / counts.sum()
    distances = abs(candidates - weighted_mean)
    candidates = candidates[distances == distances.min()]
    if len(candidates) == 1:
        return int(candidates[0])
    # The midpoint must be a genomic coordinate, like the candidates it is
    # compared with; (end - start) / 2 would only be half the peak length.
    midpoint = (start + end) / 2
    distances = abs(candidates - midpoint)
    # argmin returns the first minimum, i.e. the left-most candidate on a tie.
    return int(candidates[argmin(distances)])

# Load the peak regions and rebuild each one as a six-column interval whose
# name encodes chromosome, start, end and strand; the score column is "0".
filename = "peaks.bed"
print("Reading", filename)
intervals = []
for record in pybedtools.BedTool(filename):
    chrom = record.chrom
    begin = record.start
    stop = record.end
    strand = record.strand
    label = "%s_%d-%d_%s" % (chrom, begin, stop, strand)
    intervals.append(
        pybedtools.create_interval_from_list(
            [chrom, begin, stop, label, "0", strand]
        )
    )

# saveas() without a file name stores the intervals in a temporary file.
peaks = pybedtools.BedTool(intervals).saveas()

# counts[dataset][peak name] is an integer array with one entry per position
# of the peak, accumulating the CTSS scores over all libraries and time points.
counts = {}
datasets = ("CAGE", "HiSeq")

for dataset in datasets:
    counts[dataset] = {}
    for peak in peaks:
        name = peak.name
        # Peak names encode coordinates and strand, so a repeated name means
        # a duplicated peak.  The check must look at this dataset's table;
        # the outer dictionary only has dataset names as keys.
        assert name not in counts[dataset]
        counts[dataset][name] = zeros(peak.end - peak.start, int)
    for timepoint in timepoints:
        for library in select_libraries(dataset, timepoint):
            # Same-strand overlaps; -wa -wb keeps the CTSS columns followed
            # by the columns of the overlapping peak on each output line.
            overlap = library.intersect(peaks, s=True, wa=True, wb=True)
            for line in overlap:
                fields = line.fields
                sequence = pybedtools.create_interval_from_list(fields[:6])
                peak = pybedtools.create_interval_from_list(fields[6:])
                name = peak.name
                # Offset of the CTSS position relative to the peak start.
                index = sequence.start - peak.start
                # NOTE(review): the score is parsed as float but stored in
                # an integer array - confirm that scores are whole numbers.
                counts[dataset][name][index] += float(sequence.score)

# Rewrite peaks.gff with a "<dataset>_tss" attribute per dataset.  The new
# content goes to a temporary file first and replaces peaks.gff at the end.
# The context manager closes the temporary file even if a step below fails;
# delete=False keeps it on disk so that it can be moved afterwards.
with tempfile.NamedTemporaryFile(delete=False, mode='wt') as output:
    print("Writing %s" % output.name)
    filename = "peaks.gff"
    print("Reading %s" % filename)
    peaks = pybedtools.BedTool(filename)
    for peak in peaks:
        # Same naming scheme as the keys of the per-dataset count tables.
        name = "%s_%d-%d_%s" % (peak.chrom, peak.start, peak.end, peak.strand)
        for dataset in datasets:
            key = "%s_tss" % dataset
            tss = find_dominant_tss(peak, counts[dataset][name])
            # Remove a value left by an earlier run before storing the new
            # one (presumably so the attribute is re-added rather than
            # updated in place - TODO confirm).
            try:
                del peak.attrs[key]
            except KeyError:
                pass
            peak.attrs[key] = str(tss)
        output.write(str(peak))

print("Moving %s to %s" % (output.name, filename))
shutil.move(output.name, filename)
